Brain tumor detector¶
Nutshell¶
In this project I build a program that detects and localizes tumors in MRI images of human brains, as taught in the course Modern Artificial Intelligence, lectured by Dr. Ryan Ahmed, Ph.D., MBA.
I will train two models:
- a classifier that labels each image as containing a tumor or not
- a segmentation model that localizes the tumor within the brain
Introduction to the Brain Tumor Detection¶
Deep learning has proven to be as good as, and sometimes better than, humans at detecting diseases from X-rays, MRI scans and CT scans, so there is huge potential in using AI to speed up diagnosis and improve its accuracy. This project uses the labeled dataset from https://www.kaggle.com/datasets/mateuszbuda/lgg-mri-segmentation, which consists of 3929 brain MRI scans together with tumor location masks. The final pipeline is a two-step process:
- A ResNet deep learning classifier sorts the input images into two groups: tumor detected and no tumor detected.
- For the images where a tumor was detected, a second step is performed: a ResUNet segmentation model locates the tumor at the pixel level.
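The two-step pipeline can be sketched as a plain function. The model objects and threshold below are hypothetical stand-ins for the trained Keras models, purely for illustration:

```python
import numpy as np

def detect_and_localize(image, classifier, segmenter, threshold=0.5):
    """Two-step pipeline: classify first, segment only if a tumor is detected.

    classifier: callable returning P(tumor) for a batch of images
    segmenter:  callable returning a pixel-wise mask for a batch of images
    (both are hypothetical stand-ins for the trained models)
    """
    batch = image[np.newaxis, ...]       # add a batch dimension
    p_tumor = float(classifier(batch))   # step 1: classification
    if p_tumor < threshold:
        return None                      # no tumor detected -> no mask needed
    return segmenter(batch)[0]           # step 2: pixel-level localization
```

Running the (cheap) classifier on every image and the (expensive) segmenter only on positives is what makes the two-step design attractive.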
Image segmentation¶
Image segmentation extracts information from images at the pixel level. It is used for object recognition and localization in applications like medical imaging and self-driving cars. Deep learning approaches produce a pixel-wise mask of the image using common architectures such as CNNs, FCNs and deep encoder-decoders.
With UNet, the input and the output have the same size, so the resolution of the image is preserved. In contrast to CNN image classification, where the image is reduced to a vector and the entire image receives a single class label, UNet performs classification at the pixel level: a softmax is applied to every pixel and the loss is computed pixel-wise. In other words, the segmentation problem is solved as a per-pixel classification problem.
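Pixel-wise classification simply means the softmax is applied independently at every pixel; a minimal numpy illustration (shapes and class count chosen arbitrarily):

```python
import numpy as np

def pixelwise_softmax(logits):
    """Apply softmax over the class axis at every pixel.

    logits: array of shape (H, W, C) -- one score per class per pixel.
    Returns probabilities of the same shape, summing to 1 at each pixel.
    """
    shifted = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

# every pixel gets its own class distribution, so the output keeps the
# spatial size of the input -- the key property of UNet-style segmentation
probs = pixelwise_softmax(np.random.randn(256, 256, 2))
```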
Looking into the data¶
We have a csv file that contains the patient IDs, the locations of the images and their masks, and an indicator of whether there is a tumor in the image (1 - tumor, 0 - healthy). There are 1373 images with tumors and 2556 healthy brain images, so the dataset is imbalanced.
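The class balance can be read directly off the `mask` column. A sketch, assuming the csv described above has been loaded into a dataframe (the tiny synthetic frame here stands in for the real 3929-row file):

```python
import pandas as pd

def class_balance(df):
    """Count tumor (1) vs healthy (0) rows and return counts plus ratio."""
    counts = df['mask'].value_counts()
    return counts, counts.min() / counts.max()

# tiny synthetic stand-in for the real csv
df = pd.DataFrame({'mask': [1] * 3 + [0] * 5})
counts, ratio = class_balance(df)
```

A ratio well below 1 is what flags the imbalance noted above (here 1373/2556 ≈ 0.54 for the real data).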
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 3929 entries, 0 to 3928
Data columns (total 4 columns):
 #   Column      Non-Null Count  Dtype
---  ------      --------------  -----
 0   patient_id  3929 non-null   object
 1   image_path  3929 non-null   object
 2   mask_path   3929 non-null   object
 3   mask        3929 non-null   int64
dtypes: int64(1), object(3)
memory usage: 122.9+ KB
| | patient_id | image_path | mask_path | mask |
|---|---|---|---|---|
| 0 | TCGA_CS_5395_19981004 | TCGA_CS_5395_19981004/TCGA_CS_5395_19981004_1.tif | TCGA_CS_5395_19981004/TCGA_CS_5395_19981004_1_... | 0 |
| 1 | TCGA_CS_4944_20010208 | TCGA_CS_4944_20010208/TCGA_CS_4944_20010208_1.tif | TCGA_CS_4944_20010208/TCGA_CS_4944_20010208_1_... | 0 |
| 2 | TCGA_CS_4941_19960909 | TCGA_CS_4941_19960909/TCGA_CS_4941_19960909_1.tif | TCGA_CS_4941_19960909/TCGA_CS_4941_19960909_1_... | 0 |
| 3 | TCGA_CS_4943_20000902 | TCGA_CS_4943_20000902/TCGA_CS_4943_20000902_1.tif | TCGA_CS_4943_20000902/TCGA_CS_4943_20000902_1_... | 0 |
| 4 | TCGA_CS_5396_20010302 | TCGA_CS_5396_20010302/TCGA_CS_5396_20010302_1.tif | TCGA_CS_5396_20010302/TCGA_CS_5396_20010302_1_... | 0 |
Visualisation of the datasets¶
Below is an example of an MRI image and the matching mask. This example has a small tumor. In images where no tumor is present, the mask is completely black.
Below are visualisations of 6 MRIs with their masks overlaid in rose color, to get a sense of the data I will be using in this project.
Convolutional neural networks (CNNs)¶
- The first CNN layers extract low-level, general features
- The last couple of layers perform the classification
- Local receptive fields scan the image first, searching for simple shapes such as edges and lines
- These edges are picked up by subsequent layers and combined into more complex features
A good visualisation of the feature extraction with convolutions can be found at https://setosa.io/ev/image-kernels/
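The kind of edge detection those early layers learn can be reproduced by hand with a single kernel. A numpy sketch; the kernel is the standard Sobel x-filter, not anything from this project (and note most DL frameworks, like this code, actually compute cross-correlation rather than flipped convolution):

```python
import numpy as np

def convolve2d(image, kernel):
    """'Valid'-mode 2-D cross-correlation of a grayscale image with a kernel."""
    kh, kw = kernel.shape
    h, w = image.shape
    out = np.zeros((h - kh + 1, w - kw + 1))
    for i in range(out.shape[0]):
        for j in range(out.shape[1]):
            out[i, j] = np.sum(image[i:i + kh, j:j + kw] * kernel)
    return out

# Sobel x-filter: responds to horizontal intensity changes (vertical edges)
sobel_x = np.array([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]])

# a step image: dark left half, bright right half -> strong response at the edge
img = np.hstack([np.zeros((5, 3)), np.ones((5, 3))])
edges = convolve2d(img, sobel_x)
```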
ResNet (Residual Network)¶
- As CNNs grow deeper, vanishing gradients negatively impact network performance. The vanishing gradient problem occurs when the gradient becomes very small as it is backpropagated to the earlier layers.
- ResNet's "skip connection" feature allows training networks of 152 layers without vanishing gradient problems
- ResNet adds "identity mappings" on top of the CNN
- The ResNet deep network is pretrained on ImageNet, which contains 11 million images and 11 000 categories
ResNet paper (He et al., 2015): https://arxiv.org/pdf/1512.03385
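A skip connection in its simplest form just adds the block input back onto the block output, so the gradient always has an identity path around the weight layers. A toy numpy sketch (random untrained weights, dimensions chosen arbitrarily):

```python
import numpy as np

rng = np.random.default_rng(0)
relu = lambda x: np.maximum(x, 0.0)

def residual_block(x, W1, W2):
    """y = relu(F(x) + x): the block learns a residual F, not a full mapping."""
    f = W2 @ relu(W1 @ x)   # the "main path" F(x)
    return relu(f + x)      # the "skip connection" adds x back (identity mapping)

d = 8
x = rng.normal(size=d)
# near-zero weights: the block then behaves close to the identity function,
# which is what makes very deep stacks of such blocks trainable
W1 = rng.normal(size=(d, d)) * 0.01
W2 = rng.normal(size=(d, d)) * 0.01
y = residual_block(x, W1, W2)
```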
As seen in Figure 6 of the ResNet paper, the ResNet architectures overcome the training difficulties that deep plain networks suffer from. An ensemble of ResNets achieved a 3.57% top-5 error rate on the ImageNet test set, which is better than human performance.
Siddarth Das has made a great comparison of CNN architecture performances; you can check it out here: https://medium.com/analytics-vidhya/cnns-architectures-lenet-alexnet-vgg-googlenet-resnet-and-more-666091488df5
Transfer learning¶
Transfer learning retrains a network that was originally trained for one task so that it can be used for a similar task. Using a pretrained model can drastically reduce the computational time and the amount of training data required compared to starting from scratch. It can be compared to a salsa dancer starting to learn bachata: they will probably do a lot better than a person who has never danced before.
There are two main strategies in transfer learning:
- Freeze the trained CNN weights of the earlier layers and train only the newly added dense layers, which are initialized with random weights.
- Retrain the entire CNN while setting the learning rate very small. With too large a learning rate, the already trained weights may change too dramatically.
In this project I will use approach 1.
Transfer learning has its own challenges:
- Negative Transfer: the source task/domain is “close enough to look useful” but actually pushes the model in the wrong direction, hurting performance compared to training from scratch. This occurs when the features of old and new tasks are not related.
- Which layers to transfer / freeze: deciding what to reuse vs retrain is nontrivial; freezing too much can underfit, unfreezing too much can overfit or destabilize training.
- Representation misalignment: even if tasks are related, the internal features might not separate target classes well, especially when target cues differ (e.g., medical imaging vs natural images).
- Transfer bounds: measuring the amount of knowledge transferred is crucial to ensure model quality and robustness. How to quantify it is worth considering and remains a subject of ongoing research.
This is a great resource for transfer learning from Dipanjan Sarkar: https://towardsdatascience.com/a-comprehensive-hands-on-guide-to-transfer-learning-with-real-world-applications-in-deep-learning-212bf3b2f27a/
ResUNet¶
I will use ResUNet in the second part for the segmentation of the tumors.
- The ResUNet architecture combines the UNet backbone architecture with residual blocks
- The UNet architecture builds on Fully Convolutional Networks (FCN) and is adapted to perform well on segmentation tasks
- ResUNet has three parts:
- Encoder or contracting path
- Bottleneck
- Decoder or expansive path
The contracting path consists of several contraction blocks, each passing its input through a res-block followed by 2x2 max-pooling. The number of feature maps doubles after each block, which helps the model learn complex features effectively.
The bottleneck passes its input through a res-block, followed by a 2x2 up-sampling convolution layer.
Each decoder block takes the up-sampled input from the previous layer and concatenates it with the corresponding output features of the res-blocks in the contracting path. The result is then passed through a res-block. This ensures that the features learned while contracting are reused while reconstructing the image.
The output of the final expansion layer is passed through a 1x1 convolution layer to produce the desired output with the same size as the input.
The original paper that introduced ResUNet: https://arxiv.org/pdf/1904.00592
Part 1: Training a classifier model to detect if tumor exists or not¶
I use flow_from_dataframe for training, with batch size = 16 and class mode = 'categorical'.
# @title Preparing image generators
train_generator = datagen.flow_from_dataframe(
    dataframe = train,
    directory = './',
    x_col = 'image_path',
    y_col = 'mask',
    subset = 'training',
    batch_size = 16,
    shuffle = True,
    class_mode = 'categorical',
    target_size = (256, 256)
)

valid_generator = datagen.flow_from_dataframe(
    dataframe = train,
    directory = './',
    x_col = 'image_path',
    y_col = 'mask',
    subset = 'validation',
    batch_size = 16,
    shuffle = True,
    class_mode = 'categorical',
    target_size = (256, 256)
)

# create a data generator for the test images;
# no validation split here because this is the separate "test" data set
test_datagen = ImageDataGenerator(rescale = 1./255)
test_generator = test_datagen.flow_from_dataframe(
    dataframe = test,
    directory = './',
    x_col = 'image_path',
    y_col = 'mask',
    batch_size = 16,
    shuffle = False,
    class_mode = 'categorical',
    target_size = (256, 256)
)
Found 2839 validated image filenames belonging to 2 classes.
Found 500 validated image filenames belonging to 2 classes.
Found 590 validated image filenames belonging to 2 classes.
Below is the architecture of the ResNet50 model. For transfer learning, all of these layers are set to trainable = False to stop their weights from changing. The last layers, shown in purple, are the newly added layers that will be trained.
# @title Retrieve ResNet50 base model
# Input tensor 256 x 256 x 3
basemodel = ResNet50(weights = 'imagenet', include_top = False,
                     input_tensor = Input(shape = (256, 256, 3)))
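The freezing of the backbone layers (trainable = False) mentioned above is not shown in the cell; a sketch of the standard Keras pattern for it, using weights=None and a smaller input purely to keep the example light:

```python
from tensorflow.keras.applications import ResNet50
from tensorflow.keras.layers import Input

# random weights and a small input just for this sketch;
# the notebook uses weights='imagenet' and shape (256, 256, 3)
base = ResNet50(weights=None, include_top=False,
                input_tensor=Input(shape=(32, 32, 3)))

# freeze every backbone layer so only the added head will be trained
for layer in base.layers:
    layer.trainable = False

frozen = all(not layer.trainable for layer in base.layers)
```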
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/resnet/resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5 94765736/94765736 ━━━━━━━━━━━━━━━━━━━━ 4s 0us/step
# Add classification head to the base model
headmodel = basemodel.output
headmodel = AveragePooling2D(pool_size = (4, 4))(headmodel)
headmodel = Flatten(name = 'flatten')(headmodel)
headmodel = Dense(256, activation = 'relu')(headmodel)
headmodel = Dropout(0.3)(headmodel)
headmodel = Dense(2, activation = 'softmax')(headmodel)
fullmodel = Model(inputs = basemodel.input, outputs = headmodel)
# compile the model
fullmodel.compile(loss = 'categorical_crossentropy', optimizer='adam',
metrics=["accuracy"])
# use the early stopping to exit training
earlystopping = EarlyStopping(monitor='val_loss', mode='min', verbose = 1,
patience = 20)
# save the best model with least validation loss
checkpointer = ModelCheckpoint(filepath='classifier-resnet-weights.keras',
verbose=1, save_best_only=True)
if train_model:
    history = fullmodel.fit(
        train_generator,
        steps_per_epoch = train_generator.n // train_generator.batch_size,
        epochs = 25,
        validation_data = valid_generator,
        validation_steps = valid_generator.n // valid_generator.batch_size,
        callbacks = [checkpointer, earlystopping]
    )
Epoch 1/25 177/177 ━━━━━━━━━━━━━━━━━━━━ 0s 11s/step - accuracy: 0.7406 - loss: 1.0108 Epoch 1: val_loss improved from inf to 5.11311, saving model to classifier-resnet-weights.keras 177/177 ━━━━━━━━━━━━━━━━━━━━ 2272s 12s/step - accuracy: 0.7409 - loss: 1.0084 - val_accuracy: 0.6492 - val_loss: 5.1131 Epoch 2/25 1/177 ━━━━━━━━━━━━━━━━━━━━ 46s 265ms/step - accuracy: 1.0000 - loss: 0.1226
Epoch 2: val_loss did not improve from 5.11311 177/177 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 1.0000 - loss: 0.1226 - val_accuracy: 0.6492 - val_loss: 5.1552 Epoch 3/25 177/177 ━━━━━━━━━━━━━━━━━━━━ 0s 226ms/step - accuracy: 0.8663 - loss: 0.3271 Epoch 3: val_loss improved from 5.11311 to 0.65219, saving model to classifier-resnet-weights.keras 177/177 ━━━━━━━━━━━━━━━━━━━━ 45s 256ms/step - accuracy: 0.8663 - loss: 0.3271 - val_accuracy: 0.6472 - val_loss: 0.6522 Epoch 4/25 1/177 ━━━━━━━━━━━━━━━━━━━━ 46s 267ms/step - accuracy: 0.9375 - loss: 0.2529 Epoch 4: val_loss improved from 0.65219 to 0.65116, saving model to classifier-resnet-weights.keras 177/177 ━━━━━━━━━━━━━━━━━━━━ 5s 26ms/step - accuracy: 0.9375 - loss: 0.2529 - val_accuracy: 0.6472 - val_loss: 0.6512 Epoch 5/25 177/177 ━━━━━━━━━━━━━━━━━━━━ 0s 227ms/step - accuracy: 0.8795 - loss: 0.3138 Epoch 5: val_loss did not improve from 0.65116 177/177 ━━━━━━━━━━━━━━━━━━━━ 43s 240ms/step - accuracy: 0.8795 - loss: 0.3137 - val_accuracy: 0.6512 - val_loss: 2.4637 Epoch 6/25 1/177 ━━━━━━━━━━━━━━━━━━━━ 31s 181ms/step - accuracy: 0.8750 - loss: 0.5172 Epoch 6: val_loss did not improve from 0.65116 177/177 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.8750 - loss: 0.5172 - val_accuracy: 0.6492 - val_loss: 2.8635 Epoch 7/25 177/177 ━━━━━━━━━━━━━━━━━━━━ 0s 194ms/step - accuracy: 0.8860 - loss: 0.2790 Epoch 7: val_loss did not improve from 0.65116 177/177 ━━━━━━━━━━━━━━━━━━━━ 37s 208ms/step - accuracy: 0.8860 - loss: 0.2788 - val_accuracy: 0.6532 - val_loss: 0.6581 Epoch 8/25 1/177 ━━━━━━━━━━━━━━━━━━━━ 31s 181ms/step - accuracy: 0.9375 - loss: 0.2686 Epoch 8: val_loss did not improve from 0.65116 177/177 ━━━━━━━━━━━━━━━━━━━━ 3s 15ms/step - accuracy: 0.9375 - loss: 0.2686 - val_accuracy: 0.6552 - val_loss: 0.6583 Epoch 9/25 177/177 ━━━━━━━━━━━━━━━━━━━━ 0s 194ms/step - accuracy: 0.9208 - loss: 0.1985 Epoch 9: val_loss improved from 0.65116 to 0.58625, saving model to classifier-resnet-weights.keras 177/177 
━━━━━━━━━━━━━━━━━━━━ 39s 220ms/step - accuracy: 0.9208 - loss: 0.1986 - val_accuracy: 0.7097 - val_loss: 0.5862 Epoch 10/25 1/177 ━━━━━━━━━━━━━━━━━━━━ 44s 251ms/step - accuracy: 0.8750 - loss: 0.6005 Epoch 10: val_loss improved from 0.58625 to 0.58493, saving model to classifier-resnet-weights.keras 177/177 ━━━━━━━━━━━━━━━━━━━━ 5s 28ms/step - accuracy: 0.8750 - loss: 0.6005 - val_accuracy: 0.7016 - val_loss: 0.5849 Epoch 11/25 177/177 ━━━━━━━━━━━━━━━━━━━━ 0s 233ms/step - accuracy: 0.9280 - loss: 0.1870 Epoch 11: val_loss improved from 0.58493 to 0.37330, saving model to classifier-resnet-weights.keras 177/177 ━━━━━━━━━━━━━━━━━━━━ 46s 258ms/step - accuracy: 0.9280 - loss: 0.1871 - val_accuracy: 0.8367 - val_loss: 0.3733 Epoch 12/25 1/177 ━━━━━━━━━━━━━━━━━━━━ 45s 260ms/step - accuracy: 0.9375 - loss: 0.1971 Epoch 12: val_loss did not improve from 0.37330 177/177 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.9375 - loss: 0.1971 - val_accuracy: 0.8448 - val_loss: 0.3847 Epoch 13/25 177/177 ━━━━━━━━━━━━━━━━━━━━ 0s 227ms/step - accuracy: 0.9334 - loss: 0.1942 Epoch 13: val_loss improved from 0.37330 to 0.17680, saving model to classifier-resnet-weights.keras 177/177 ━━━━━━━━━━━━━━━━━━━━ 45s 254ms/step - accuracy: 0.9334 - loss: 0.1942 - val_accuracy: 0.9415 - val_loss: 0.1768 Epoch 14/25 1/177 ━━━━━━━━━━━━━━━━━━━━ 44s 252ms/step - accuracy: 1.0000 - loss: 0.0706 Epoch 14: val_loss improved from 0.17680 to 0.17534, saving model to classifier-resnet-weights.keras 177/177 ━━━━━━━━━━━━━━━━━━━━ 5s 28ms/step - accuracy: 1.0000 - loss: 0.0706 - val_accuracy: 0.9415 - val_loss: 0.1753 Epoch 15/25 177/177 ━━━━━━━━━━━━━━━━━━━━ 0s 229ms/step - accuracy: 0.9482 - loss: 0.1595 Epoch 15: val_loss did not improve from 0.17534 177/177 ━━━━━━━━━━━━━━━━━━━━ 44s 243ms/step - accuracy: 0.9481 - loss: 0.1595 - val_accuracy: 0.8690 - val_loss: 0.4619 Epoch 16/25 1/177 ━━━━━━━━━━━━━━━━━━━━ 33s 191ms/step - accuracy: 0.9375 - loss: 0.1972 Epoch 16: val_loss did not improve from 0.17534 
177/177 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.9375 - loss: 0.1972 - val_accuracy: 0.8488 - val_loss: 0.5329 Epoch 17/25 177/177 ━━━━━━━━━━━━━━━━━━━━ 0s 193ms/step - accuracy: 0.9404 - loss: 0.1725 Epoch 17: val_loss did not improve from 0.17534 177/177 ━━━━━━━━━━━━━━━━━━━━ 37s 207ms/step - accuracy: 0.9404 - loss: 0.1725 - val_accuracy: 0.8629 - val_loss: 0.4133 Epoch 18/25 1/177 ━━━━━━━━━━━━━━━━━━━━ 32s 186ms/step - accuracy: 0.8750 - loss: 0.3450 Epoch 18: val_loss did not improve from 0.17534 177/177 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.8750 - loss: 0.3450 - val_accuracy: 0.8770 - val_loss: 0.3382 Epoch 19/25 177/177 ━━━━━━━━━━━━━━━━━━━━ 0s 195ms/step - accuracy: 0.9301 - loss: 0.1879 Epoch 19: val_loss did not improve from 0.17534 177/177 ━━━━━━━━━━━━━━━━━━━━ 37s 208ms/step - accuracy: 0.9301 - loss: 0.1878 - val_accuracy: 0.9234 - val_loss: 0.2028 Epoch 20/25 1/177 ━━━━━━━━━━━━━━━━━━━━ 32s 183ms/step - accuracy: 0.9375 - loss: 0.1574 Epoch 20: val_loss did not improve from 0.17534 177/177 ━━━━━━━━━━━━━━━━━━━━ 3s 16ms/step - accuracy: 0.9375 - loss: 0.1574 - val_accuracy: 0.9254 - val_loss: 0.1922 Epoch 21/25 177/177 ━━━━━━━━━━━━━━━━━━━━ 0s 194ms/step - accuracy: 0.9541 - loss: 0.1313 Epoch 21: val_loss improved from 0.17534 to 0.16246, saving model to classifier-resnet-weights.keras 177/177 ━━━━━━━━━━━━━━━━━━━━ 39s 219ms/step - accuracy: 0.9541 - loss: 0.1313 - val_accuracy: 0.9516 - val_loss: 0.1625 Epoch 22/25 1/177 ━━━━━━━━━━━━━━━━━━━━ 47s 267ms/step - accuracy: 1.0000 - loss: 0.0322 Epoch 22: val_loss did not improve from 0.16246 177/177 ━━━━━━━━━━━━━━━━━━━━ 3s 17ms/step - accuracy: 1.0000 - loss: 0.0322 - val_accuracy: 0.9496 - val_loss: 0.1651 Epoch 23/25 177/177 ━━━━━━━━━━━━━━━━━━━━ 0s 235ms/step - accuracy: 0.9626 - loss: 0.0960 Epoch 23: val_loss did not improve from 0.16246 177/177 ━━━━━━━━━━━━━━━━━━━━ 44s 249ms/step - accuracy: 0.9626 - loss: 0.0961 - val_accuracy: 0.9415 - val_loss: 0.1676 Epoch 24/25 1/177 ━━━━━━━━━━━━━━━━━━━━ 
32s 184ms/step - accuracy: 0.9375 - loss: 0.1396 Epoch 24: val_loss did not improve from 0.16246 177/177 ━━━━━━━━━━━━━━━━━━━━ 3s 13ms/step - accuracy: 0.9375 - loss: 0.1396 - val_accuracy: 0.9456 - val_loss: 0.1633 Epoch 25/25 177/177 ━━━━━━━━━━━━━━━━━━━━ 0s 195ms/step - accuracy: 0.9668 - loss: 0.0948 Epoch 25: val_loss did not improve from 0.16246 177/177 ━━━━━━━━━━━━━━━━━━━━ 37s 208ms/step - accuracy: 0.9667 - loss: 0.0949 - val_accuracy: 0.8427 - val_loss: 0.8755
Assess trained model performance¶
The model accuracy on the test set is 0.97.
from sklearn.metrics import classification_report
print(classification_report(y_true, y_pred, labels = [0,1]))
precision recall f1-score support
0 0.98 0.98 0.98 383
1 0.97 0.96 0.96 207
micro avg 0.97 0.97 0.97 590
macro avg 0.97 0.97 0.97 590
weighted avg 0.97 0.97 0.97 590
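`y_true` and `y_pred` in the report above come from the test generator; with a two-unit softmax head, `y_pred` was presumably derived by taking the argmax over the class probabilities, along these lines (the probability values here are made up for illustration):

```python
import numpy as np

# hypothetical softmax outputs for three test images, shape (n_samples, 2)
probabilities = np.array([[0.9, 0.1],
                          [0.2, 0.8],
                          [0.4, 0.6]])

# pick the most probable class per image: 0 = healthy, 1 = tumor
y_pred = np.argmax(probabilities, axis=1)
```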
Part 2: Building a segmentation model to localise tumors¶
def resblock(X, f):
    # make a copy of the input
    X_copy = X

    # main path
    # Read more about he_normal: https://medium.com/@prateekvishnu/xavier-and-he-normal-he-et-al-initialization-8e3d7a087528
    X = Conv2D(f, kernel_size = (1,1), strides = (1,1), kernel_initializer = 'he_normal')(X)
    X = BatchNormalization()(X)
    X = Activation('relu')(X)
    X = Conv2D(f, kernel_size = (3,3), strides = (1,1), padding = 'same', kernel_initializer = 'he_normal')(X)
    X = BatchNormalization()(X)

    # short path
    # Read more here: https://towardsdatascience.com/understanding-and-coding-a-resnet-in-keras-446d7ff84d33
    X_copy = Conv2D(f, kernel_size = (1,1), strides = (1,1), kernel_initializer = 'he_normal')(X_copy)
    X_copy = BatchNormalization()(X_copy)

    # add the outputs of the main path and the short path together
    X = Add()([X, X_copy])
    X = Activation('relu')(X)
    return X

# function to upscale the input and concatenate it with the passed skip connection
def upsample_concat(x, skip):
    x = UpSampling2D((2,2))(x)
    merge = Concatenate()([x, skip])
    return merge
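The shape arithmetic behind upsample_concat can be checked in plain numpy: 2x2 upsampling doubles both spatial dimensions so the tensor can be concatenated channel-wise with the matching encoder output. A sketch with channel counts taken from the bottleneck/stage-4 pairing below:

```python
import numpy as np

def upsample_concat_shapes(x, skip):
    """Mimic UpSampling2D((2,2)) + Concatenate() on (H, W, C) arrays."""
    up = x.repeat(2, axis=0).repeat(2, axis=1)   # nearest-neighbour 2x2 upsample
    return np.concatenate([up, skip], axis=-1)   # channel-wise concatenation

x = np.zeros((16, 16, 256))     # e.g. bottleneck output
skip = np.zeros((32, 32, 128))  # matching encoder feature map
merged = upsample_concat_shapes(x, skip)
```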
input_shape = (256,256,3)
# Input tensor shape
X_input = Input(input_shape)
# Stage 1
conv1_in = Conv2D(16,3,activation= 'relu', padding = 'same', kernel_initializer ='he_normal')(X_input)
conv1_in = BatchNormalization()(conv1_in)
conv1_in = Conv2D(16,3,activation= 'relu', padding = 'same', kernel_initializer ='he_normal')(conv1_in)
conv1_in = BatchNormalization()(conv1_in)
pool_1 = MaxPool2D(pool_size = (2,2))(conv1_in)
# Stage 2
conv2_in = resblock(pool_1, 32)
pool_2 = MaxPool2D(pool_size = (2,2))(conv2_in)
# Stage 3
conv3_in = resblock(pool_2, 64)
pool_3 = MaxPool2D(pool_size = (2,2))(conv3_in)
# Stage 4
conv4_in = resblock(pool_3, 128)
pool_4 = MaxPool2D(pool_size = (2,2))(conv4_in)
# Stage 5 (Bottle Neck)
conv5_in = resblock(pool_4, 256)
# Upscale stage 1
up_1 = upsample_concat(conv5_in, conv4_in)
up_1 = resblock(up_1, 128)
# Upscale stage 2
up_2 = upsample_concat(up_1, conv3_in)
up_2 = resblock(up_2, 64)
# Upscale stage 3
up_3 = upsample_concat(up_2, conv2_in)
up_3 = resblock(up_3, 32)
# Upscale stage 4
up_4 = upsample_concat(up_3, conv1_in)
up_4 = resblock(up_4, 16)
# Final Output
output = Conv2D(1, (1,1), padding = "same", activation = "sigmoid")(up_4)
model_seg = Model(inputs = X_input, outputs = output )
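The compile step below references custom focal_tversky_fixed and tversky_fixed functions defined elsewhere in the notebook. A numpy sketch of the standard Tversky / focal-Tversky formulation they presumably implement (the alpha, beta and gamma values are common defaults, not taken from this project):

```python
import numpy as np

def tversky(y_true, y_pred, alpha=0.7, beta=0.3, smooth=1e-6):
    """Tversky index: a generalised Dice score that weights false negatives
    (alpha) and false positives (beta) differently -- useful when the tumor
    occupies only a small fraction of the image."""
    y_true, y_pred = y_true.ravel(), y_pred.ravel()
    tp = np.sum(y_true * y_pred)            # true positives
    fn = np.sum(y_true * (1 - y_pred))      # false negatives
    fp = np.sum((1 - y_true) * y_pred)      # false positives
    return (tp + smooth) / (tp + alpha * fn + beta * fp + smooth)

def focal_tversky(y_true, y_pred, gamma=0.75):
    """Focal variant: raising (1 - Tversky) to gamma < 1 emphasises hard examples."""
    return (1.0 - tversky(y_true, y_pred)) ** gamma

# a perfect prediction should give a Tversky index near 1 and a loss near 0
mask = np.array([[0, 1], [1, 0]], dtype=float)
```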
# @title Training the segmentation model
# Compile the model
adam = keras.optimizers.Adam(learning_rate = 0.05, epsilon = 0.1)
model_seg.compile(optimizer = adam, loss = focal_tversky_fixed, metrics = [tversky_fixed])
# use early stopping to exit training if validation loss is not decreasing even after certain epochs (patience)
earlystopping = EarlyStopping(monitor='val_loss', mode='min', verbose=1, patience=20)
# save the best model with lower validation loss
checkpointer = ModelCheckpoint(filepath="ResUNet-weights.keras", verbose=1, save_best_only=True)
if train_model:
    history = model_seg.fit(
        training_generator,
        epochs = 25,
        validation_data = validation_generator,
        callbacks = [checkpointer, earlystopping],
    )
Epoch 1/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 22s/step - loss: 0.9111 - tversky_fixed: 0.1166
Epoch 1: val_loss improved from inf to 0.90052, saving model to ResUNet-weights.keras 72/72 ━━━━━━━━━━━━━━━━━━━━ 1763s 24s/step - loss: 0.9111 - tversky_fixed: 0.1167 - val_loss: 0.9005 - val_tversky_fixed: 0.1303 Epoch 2/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 449ms/step - loss: 0.8745 - tversky_fixed: 0.1635 Epoch 2: val_loss improved from 0.90052 to 0.79715, saving model to ResUNet-weights.keras 72/72 ━━━━━━━━━━━━━━━━━━━━ 40s 566ms/step - loss: 0.8741 - tversky_fixed: 0.1640 - val_loss: 0.7971 - val_tversky_fixed: 0.2605 Epoch 3/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 143ms/step - loss: 0.6143 - tversky_fixed: 0.4735 Epoch 3: val_loss improved from 0.79715 to 0.63941, saving model to ResUNet-weights.keras 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 166ms/step - loss: 0.6130 - tversky_fixed: 0.4750 - val_loss: 0.6394 - val_tversky_fixed: 0.4484 Epoch 4/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 143ms/step - loss: 0.3883 - tversky_fixed: 0.7143 Epoch 4: val_loss improved from 0.63941 to 0.44970, saving model to ResUNet-weights.keras 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 166ms/step - loss: 0.3881 - tversky_fixed: 0.7145 - val_loss: 0.4497 - val_tversky_fixed: 0.6547 Epoch 5/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 145ms/step - loss: 0.3613 - tversky_fixed: 0.7407 Epoch 5: val_loss improved from 0.44970 to 0.36883, saving model to ResUNet-weights.keras 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 164ms/step - loss: 0.3612 - tversky_fixed: 0.7408 - val_loss: 0.3688 - val_tversky_fixed: 0.7334 Epoch 6/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 152ms/step - loss: 0.3129 - tversky_fixed: 0.7862 Epoch 6: val_loss did not improve from 0.36883 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 166ms/step - loss: 0.3129 - tversky_fixed: 0.7861 - val_loss: 0.4216 - val_tversky_fixed: 0.6795 Epoch 7/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 152ms/step - loss: 0.2840 - tversky_fixed: 0.8115 Epoch 7: val_loss improved from 0.36883 to 0.30192, saving model to ResUNet-weights.keras 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 171ms/step - loss: 0.2839 - tversky_fixed: 0.8116 - val_loss: 0.3019 - 
val_tversky_fixed: 0.7967 Epoch 8/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 145ms/step - loss: 0.2486 - tversky_fixed: 0.8425 Epoch 8: val_loss did not improve from 0.30192 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 161ms/step - loss: 0.2487 - tversky_fixed: 0.8425 - val_loss: 0.3242 - val_tversky_fixed: 0.7760 Epoch 9/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 143ms/step - loss: 0.2491 - tversky_fixed: 0.8413 Epoch 9: val_loss improved from 0.30192 to 0.20870, saving model to ResUNet-weights.keras 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 162ms/step - loss: 0.2491 - tversky_fixed: 0.8413 - val_loss: 0.2087 - val_tversky_fixed: 0.8757 Epoch 10/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 152ms/step - loss: 0.2133 - tversky_fixed: 0.8716 Epoch 10: val_loss did not improve from 0.20870 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 165ms/step - loss: 0.2134 - tversky_fixed: 0.8715 - val_loss: 0.5172 - val_tversky_fixed: 0.5828 Epoch 11/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 148ms/step - loss: 0.2107 - tversky_fixed: 0.8738 Epoch 11: val_loss did not improve from 0.20870 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 162ms/step - loss: 0.2107 - tversky_fixed: 0.8738 - val_loss: 0.2375 - val_tversky_fixed: 0.8524 Epoch 12/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 147ms/step - loss: 0.1951 - tversky_fixed: 0.8861 Epoch 12: val_loss did not improve from 0.20870 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 160ms/step - loss: 0.1950 - tversky_fixed: 0.8861 - val_loss: 0.2418 - val_tversky_fixed: 0.8490 Epoch 13/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 149ms/step - loss: 0.1867 - tversky_fixed: 0.8925 Epoch 13: val_loss did not improve from 0.20870 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 163ms/step - loss: 0.1867 - tversky_fixed: 0.8925 - val_loss: 0.2510 - val_tversky_fixed: 0.8405 Epoch 14/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 148ms/step - loss: 0.1935 - tversky_fixed: 0.8873 Epoch 14: val_loss improved from 0.20870 to 0.18447, saving model to ResUNet-weights.keras 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 166ms/step - loss: 0.1934 - tversky_fixed: 0.8874 - val_loss: 0.1845 - val_tversky_fixed: 0.8946 Epoch 15/25 
72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 147ms/step - loss: 0.1678 - tversky_fixed: 0.9069 Epoch 15: val_loss improved from 0.18447 to 0.17346, saving model to ResUNet-weights.keras 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 169ms/step - loss: 0.1679 - tversky_fixed: 0.9068 - val_loss: 0.1735 - val_tversky_fixed: 0.9031 Epoch 16/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 152ms/step - loss: 0.1673 - tversky_fixed: 0.9073 Epoch 16: val_loss did not improve from 0.17346 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 164ms/step - loss: 0.1674 - tversky_fixed: 0.9073 - val_loss: 0.1907 - val_tversky_fixed: 0.8900 Epoch 17/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 145ms/step - loss: 0.1665 - tversky_fixed: 0.9079 Epoch 17: val_loss did not improve from 0.17346 72/72 ━━━━━━━━━━━━━━━━━━━━ 11s 159ms/step - loss: 0.1665 - tversky_fixed: 0.9079 - val_loss: 0.1871 - val_tversky_fixed: 0.8927 Epoch 18/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 145ms/step - loss: 0.1600 - tversky_fixed: 0.9126 Epoch 18: val_loss did not improve from 0.17346 72/72 ━━━━━━━━━━━━━━━━━━━━ 11s 157ms/step - loss: 0.1600 - tversky_fixed: 0.9126 - val_loss: 0.1755 - val_tversky_fixed: 0.9013 Epoch 19/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 146ms/step - loss: 0.1584 - tversky_fixed: 0.9139 Epoch 19: val_loss did not improve from 0.17346 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 160ms/step - loss: 0.1584 - tversky_fixed: 0.9139 - val_loss: 0.1815 - val_tversky_fixed: 0.8962 Epoch 20/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 153ms/step - loss: 0.1483 - tversky_fixed: 0.9206 Epoch 20: val_loss did not improve from 0.17346 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 167ms/step - loss: 0.1483 - tversky_fixed: 0.9206 - val_loss: 0.1904 - val_tversky_fixed: 0.8904 Epoch 21/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 145ms/step - loss: 0.1433 - tversky_fixed: 0.9246 Epoch 21: val_loss did not improve from 0.17346 72/72 ━━━━━━━━━━━━━━━━━━━━ 11s 158ms/step - loss: 0.1433 - tversky_fixed: 0.9246 - val_loss: 0.1797 - val_tversky_fixed: 0.8978 Epoch 22/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 146ms/step - loss: 0.1313 - tversky_fixed: 0.9329 
Epoch 22: val_loss improved from 0.17346 to 0.17340, saving model to ResUNet-weights.keras 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 163ms/step - loss: 0.1314 - tversky_fixed: 0.9329 - val_loss: 0.1734 - val_tversky_fixed: 0.9026 Epoch 23/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 145ms/step - loss: 0.1351 - tversky_fixed: 0.9304 Epoch 23: val_loss improved from 0.17340 to 0.16084, saving model to ResUNet-weights.keras 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 166ms/step - loss: 0.1350 - tversky_fixed: 0.9304 - val_loss: 0.1608 - val_tversky_fixed: 0.9124 Epoch 24/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 145ms/step - loss: 0.1293 - tversky_fixed: 0.9341 Epoch 24: val_loss did not improve from 0.16084 72/72 ━━━━━━━━━━━━━━━━━━━━ 12s 161ms/step - loss: 0.1294 - tversky_fixed: 0.9341 - val_loss: 0.2011 - val_tversky_fixed: 0.8821 Epoch 25/25 72/72 ━━━━━━━━━━━━━━━━━━━━ 0s 145ms/step - loss: 0.1342 - tversky_fixed: 0.9307 Epoch 25: val_loss did not improve from 0.16084 72/72 ━━━━━━━━━━━━━━━━━━━━ 11s 158ms/step - loss: 0.1342 - tversky_fixed: 0.9307 - val_loss: 0.1900 - val_tversky_fixed: 0.8906
Assessing the trained segmentation model performance¶
To assess the performance, the predicted and actual masks of 10 test cases are printed below. The model has not seen this data before.